from transformers import AutoTokenizer, AutoModelForSequenceClassification
import nltk, numpy as np, torch, os, json
from utils_misc import batcher
from tqdm import tqdm

# HuggingFace model cards for the 8 supported NLI checkpoints, plus the
# output-label layout of each checkpoint: which logit index corresponds to
# "entailment" and which to "contradiction" (the remaining index is neutral).
model_map = {
    "snli-base": {"model_card": "boychaboy/SNLI_roberta-base", "entailment_idx": 0, "contradiction_idx": 2},
    "snli-large": {"model_card": "boychaboy/SNLI_roberta-large", "entailment_idx": 0, "contradiction_idx": 2},
    "mnli-base": {"model_card": "microsoft/deberta-base-mnli", "entailment_idx": 2, "contradiction_idx": 0},
    "mnli": {"model_card": "roberta-large-mnli", "entailment_idx": 2, "contradiction_idx": 0},
    "anli": {"model_card": "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli", "entailment_idx": 0,
             "contradiction_idx": 2},
    "vitc-base": {"model_card": "tals/albert-base-vitaminc-mnli", "entailment_idx": 0, "contradiction_idx": 1},
    "vitc": {"model_card": "tals/albert-xlarge-vitaminc-mnli", "entailment_idx": 0, "contradiction_idx": 1},
    "vitc-only": {"model_card": "tals/albert-xlarge-vitaminc", "entailment_idx": 0, "contradiction_idx": 1},
    # "decomp": 0,
}


def card_to_name(card):
    """Map a HuggingFace model card back to its short model name.

    Unknown cards are returned unchanged (pass-through).
    """
    reverse_map = {info["model_card"]: name for name, info in model_map.items()}
    return reverse_map.get(card, card)


def name_to_card(name):
    """Map a short model name to its HuggingFace model card.

    Unknown names are returned unchanged (pass-through).
    """
    if name not in model_map:
        return name
    return model_map[name]["model_card"]


def get_neutral_idx(ent_idx, con_idx):
    """Return the remaining label index in {0, 1, 2}: the neutral class."""
    return next(iter({0, 1, 2} - {ent_idx, con_idx}))


def split_sentences(text):
    """Split `text` into sentences with nltk, filtering and capping length.

    Sentences of 10 characters or fewer are dropped; sentences of 256
    characters or more are truncated to their first 256 characters so a
    single run-on sentence cannot blow up the NLI input.

    Bug fix: the previous version appended BOTH the truncated slice and the
    full sentence for long sentences, duplicating them in the output; each
    sentence now appears exactly once.
    """
    sentences = []
    for sentence in nltk.tokenize.sent_tokenize(text):
        if len(sentence) > 10:
            if len(sentence) >= 256:
                sentence = sentence[:256]
            sentences.append(sentence)
    return sentences


def split_2sents(text):
    """Split `text` into overlapping windows of two consecutive sentences.

    Sentences of 10 characters or fewer are dropped first. The final window
    contains only the last sentence (the slice runs past the end).
    """
    kept = [sent for sent in nltk.tokenize.sent_tokenize(text) if len(sent) > 10]
    return [" ".join(kept[i:(i + 2)]) for i in range(len(kept))]


def split_paragraphs(text):
    """Split `text` into paragraphs, dropping very short ones.

    Prefers blank-line separation ("\n\n"); falls back to single newlines
    when no blank line is present. Paragraphs of 10 characters or fewer
    (after stripping) are discarded.
    """
    separator = "\n\n" if "\n\n" in text else "\n"
    stripped = (chunk.strip() for chunk in text.split(separator))
    return [chunk for chunk in stripped if len(chunk) > 10]


def split_text(text, granularity="sentence"):
    """Split `text` into chunks at the requested granularity.

    Args:
        text: input document or summary.
        granularity: one of "document", "paragraph", "sentence", "2sents",
            "mixed" ("mixed" concatenates sentence and paragraph chunks).

    Returns:
        A list of text chunks.

    Raises:
        ValueError: if `granularity` is not recognized. (Previously an
            unrecognized value silently returned None, which would only
            surface later as a confusing TypeError in the caller.)
    """
    if granularity == "document":
        return [text]
    elif granularity == "paragraph":
        return split_paragraphs(text)
    elif granularity == "sentence":
        return split_sentences(text)
    elif granularity == "2sents":
        return split_2sents(text)
    elif granularity == "mixed":
        return split_sentences(text) + split_paragraphs(text)
    raise ValueError("Unrecognized `granularity` %s" % (granularity,))


class SummaCImager:
    """Computes an NLI "image" for a (document, summary) pair.

    The image is a (3, N_doc_chunks, N_sum_chunks) array holding the
    entailment (channel 0), contradiction (channel 1) and neutral (channel 2)
    probabilities for every (document chunk, summary chunk) pair, with an
    optional in-memory/on-disk cache keyed by the raw text pair.
    """

    def __init__(self, model_name="mnli", granularity="paragraph", use_cache=True, max_doc_sents=100, device="cuda",
                 **kwargs):

        self.grans = granularity.split("-")  # granularity such as "document" or "paragraph"; two values may be joined with "-"

        assert all(gran in ["paragraph", "sentence", "document", "2sents", "mixed"] for gran in self.grans) and len(
            self.grans) <= 2, "Unrecognized `granularity` %s" % (granularity)  # at most two granularities: document side and summary side
        assert model_name in model_map.keys(), "Unrecognized model name: `%s`" % (model_name)  # model name sanity check

        self.model_name = model_name
        if model_name != "decomp":  # model_name is never "decomp" here, since "decomp" is commented out of model_map
            self.model_card = name_to_card(model_name)  # resolve the HuggingFace model card
            # NLI classifies a sentence pair into one of three relations:
            # 1. entailment  2. contradiction  3. neutral
            self.entailment_idx = model_map[model_name]["entailment_idx"]
            self.contradiction_idx = model_map[model_name]["contradiction_idx"]
            self.neutral_idx = get_neutral_idx(self.entailment_idx, self.contradiction_idx)

        self.granularity = granularity
        self.use_cache = use_cache
        # NOTE(review): relative path — resolves against the current working
        # directory, not this file's location; verify this is intentional.
        self.cache_folder = "../export/share/plaban/summac_cache/"

        self.max_doc_sents = max_doc_sents  # cap on document-side chunks
        self.max_input_length = 500         # tokenizer max_length for premise+hypothesis
        self.device = device
        self.cache = {}                     # (original, generated) -> np image
        self.model = None  # Lazy loader

    def load_nli(self):
        """Load the NLI tokenizer and model (lazy; called on first real use)."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_card)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_card).eval()
        self.model.to(self.device)
        if self.device == "cuda":
            self.model.half()  # fp16 inference on GPU

    def build_chunk_dataset(self, original, generated, pair_idx=None):
        """Chunk both texts and build the cross-product of NLI input pairs.

        Returns (dataset, N_ori, N_gen) where dataset is a list of dicts with
        keys premise / hypothesis / doc_i / gen_i / pair_idx.
        """
        if len(self.grans) == 1:
            gran_doc, gran_sum = self.grans[0], self.grans[0]
        else:
            gran_doc, gran_sum = self.grans[0], self.grans[1]

        original_chunks = split_text(original, granularity=gran_doc)[:self.max_doc_sents]
        generated_chunks = split_text(generated, granularity=gran_sum)

        N_ori, N_gen = len(original_chunks), len(generated_chunks)
        dataset = [{"premise": original_chunks[i], "hypothesis": generated_chunks[j], "doc_i": i, "gen_i": j,
                    "pair_idx": pair_idx} for i in range(N_ori) for j in range(N_gen)]
        return dataset, N_ori, N_gen

    def build_image(self, original, generated):
        """Compute (or fetch from cache) the (3, N_ori, N_gen) NLI image for one pair."""
        cache_key = (original, generated)
        if self.use_cache and cache_key in self.cache:
            cached_image = self.cache[cache_key]
            # Trim in case the cached image was built with a larger max_doc_sents.
            cached_image = cached_image[:, :self.max_doc_sents, :]
            return cached_image

        dataset, N_ori, N_gen = self.build_chunk_dataset(original, generated)

        if len(dataset) == 0:
            # No usable chunks on one side: return a degenerate all-zero image.
            return np.zeros((3, 1, 1))

        image = np.zeros((3, N_ori, N_gen))

        if self.model is None:
            self.load_nli()

        for batch in batcher(dataset, batch_size=20):
            batch_prems = [b["premise"] for b in batch]
            batch_hypos = [b["hypothesis"] for b in batch]
            batch_tokens = self.tokenizer.batch_encode_plus(
                list(zip(batch_prems, batch_hypos)), padding=True, truncation=True,
                max_length=self.max_input_length, return_tensors="pt", truncation_strategy="only_first"
            )
            with torch.no_grad():
                model_outputs = self.model(**{k: v.to(self.device) for k, v in batch_tokens.items()})

            batch_probs = torch.nn.functional.softmax(model_outputs["logits"], dim=-1)
            batch_evids = batch_probs[:, self.entailment_idx].tolist()
            batch_conts = batch_probs[:, self.contradiction_idx].tolist()
            batch_neuts = batch_probs[:, self.neutral_idx].tolist()

            # Scatter the three probabilities into the image at (doc_i, gen_i).
            for b, evid, cont, neut in zip(batch, batch_evids, batch_conts, batch_neuts):
                image[0, b["doc_i"], b["gen_i"]] = evid
                image[1, b["doc_i"], b["gen_i"]] = cont
                image[2, b["doc_i"], b["gen_i"]] = neut

        if self.use_cache:
            self.cache[cache_key] = image
        return image

    # originals are the documents, generateds the summaries; parallel lists, one pair per index.
    def build_images(self, originals, generateds, batch_size=128):
        """Batch variant of build_image: computes all uncached pairs in one pass."""
        # Collect the (original, generated) pairs not already in the cache.
        todo_originals, todo_generateds = [], []
        for ori, gen in zip(originals, generateds):
            cache_key = (ori, gen)
            if cache_key not in self.cache:
                todo_originals.append(ori)
                todo_generateds.append(gen)

        total_dataset = []
        todo_images = []
        for pair_idx, (ori, gen) in enumerate(zip(todo_originals, todo_generateds)):
            dataset, N_ori, N_gen = self.build_chunk_dataset(
                ori, gen, pair_idx=pair_idx
            )  # dataset: list of dicts (premise/hypothesis/doc_i/gen_i/pair_idx); N_ori/N_gen are chunk counts
            if len(dataset) == 0:
                image = np.zeros((3, 1, 1))
            else:
                image = np.zeros((3, N_ori, N_gen))
            todo_images.append(image)
            total_dataset += dataset  # accumulate every pair's dataset into one flat list
        if len(total_dataset) > 0 and self.model is None:  # model is lazily loaded; only needed when something is uncached # Can't just rely on the cache
            self.load_nli()  # loads the NLI model and also sets self.tokenizer

        for batch in batcher(total_dataset, batch_size=batch_size):
            batch_prems = [b["premise"] for b in batch]  # premises (document chunks)
            batch_hypos = [b["hypothesis"] for b in batch]  # hypotheses (summary chunks)
            batch_tokens = self.tokenizer.batch_encode_plus(
                list(zip(batch_prems, batch_hypos)), padding=True, truncation=True,
                max_length=self.max_input_length, return_tensors="pt", truncation_strategy="only_first"
            )
            with torch.no_grad():
                model_outputs = self.model(**{k: v.to(self.device) for k, v in batch_tokens.items()})

            batch_probs = torch.nn.functional.softmax(model_outputs["logits"], dim=-1)
            batch_evids = batch_probs[:, self.entailment_idx].tolist()
            batch_conts = batch_probs[:, self.contradiction_idx].tolist()
            batch_neuts = batch_probs[:, self.neutral_idx].tolist()

            # pair_idx routes each result back to the image of its (doc, summary) pair.
            for b, evid, cont, neut in zip(batch, batch_evids, batch_conts, batch_neuts):
                image = todo_images[b["pair_idx"]]
                image[0, b["doc_i"], b["gen_i"]] = evid
                image[1, b["doc_i"], b["gen_i"]] = cont
                image[2, b["doc_i"], b["gen_i"]] = neut

        for pair_idx, (ori, gen) in enumerate(zip(todo_originals, todo_generateds)):
            cache_key = (ori, gen)
            self.cache[cache_key] = todo_images[pair_idx]

        # Every requested pair is now cached; return them in input order.
        images = [self.cache[(ori, gen)] for ori, gen in zip(originals, generateds)]
        return images

    def get_cache_file(self):
        """Path of the JSON cache file for this (model, granularity) combination."""
        return os.path.join(self.cache_folder, "cache_%s_%s.json" % (self.model_name, self.granularity))

    def save_cache(self):
        """Persist the in-memory cache to disk as JSON.

        NOTE(review): keys are joined with the "[///]" separator; a text pair
        containing that literal substring would round-trip incorrectly.
        """
        cache_cp = {"[///]".join(k): v.tolist() for k, v in self.cache.items()}
        with open(self.get_cache_file(), "w") as f:
            json.dump(cache_cp, f)

    # Load the on-disk cache, if present, from self.cache_folder.
    def load_cache(self):
        cache_file = self.get_cache_file()
        if os.path.isfile(cache_file):
            with open(cache_file, "r") as f:
                cache_cp = json.load(f)
                self.cache = {tuple(k.split("[///]")): np.array(v) for k, v in cache_cp.items()}


class SummaCConv(torch.nn.Module):
    """Trained SummaC variant.

    For each summary chunk (a column of the NLI image), builds a histogram of
    the NLI probabilities over document chunks, scores each histogram row with
    a learned linear layer, aggregates across rows (mean/min/max/all), and
    emits 2-class logits via a final linear layer.
    """

    def __init__(
            self, models=["mnli", "anli", "vitc"], bins='even50', granularity="sentence", nli_labels="e",
            device="cuda", start_file=None, imager_load_cache=True, agg="mean", **kwargs
    ):
        # `bins` should be `even%d` or `percentiles`
        # NOTE(review): the line above says "percentiles" but the code below
        # checks bins == "percentile"; any unmatched value leaves self.bins
        # unset and fails later with AttributeError. Also note the mutable
        # default for `models` (harmless here: the list is never mutated).
        assert nli_labels in ["e", "c", "n", "ec", "en", "cn", "ecn"], \
            "Unrecognized nli_labels argument %s" % (nli_labels)

        super(SummaCConv, self).__init__()

        self.device = device
        self.models = models

        # One imager per underlying NLI model; their images are stacked depth-wise.
        self.imagers = []
        for model_name in models:
            self.imagers.append(SummaCImager(
                model_name=model_name, granularity=granularity, device=self.device, **kwargs
            ))

        if imager_load_cache:
            for imager in self.imagers:
                imager.load_cache()

        assert len(self.imagers) > 0, "Imager names were empty or unrecognized"

        if "even" in bins:
            # Evenly spaced bin edges over [0, 1].
            n_bins = int(bins.replace("even", ""))
            self.bins = list(np.arange(0, 1, 1 / n_bins)) + [1.0]
        elif bins == "percentile":
            self.bins = [
                0.0, 0.01, 0.02, 0.03, 0.04, 0.07, 0.13, 0.37,
                0.90, 0.91, 0.92, 0.93, 0.94, 0.95, 0.955, 0.96,
                0.965, 0.97, 0.975, 0.98, 0.985, 0.99, 0.995, 1.0
            ]
            # Based on the percentile of the distribution on some large number of summaries

        self.nli_labels = nli_labels
        self.n_bins = len(self.bins) - 1
        self.n_rows = 10  # max number of summary chunks scored per sample (see compute_histogram)
        self.n_labels = 2
        self.n_depth = len(self.imagers) * len(self.nli_labels)
        self.full_size = self.n_depth * self.n_bins

        self.agg = agg  # row aggregation: "mean", "min", "max", or "all"

        self.mlp = torch.nn.Linear(self.full_size, 1).to(device)
        self.layer_final = torch.nn.Linear(3, self.n_labels).to(device)

        if start_file == "default":
            # Fetch the released checkpoint if it is not present locally.
            start_file = "summac_conv_vitc_sent_perc_e.bin"
            if not os.path.isfile("summac_conv_vitc_sent_perc_e.bin"):
                os.system("wget https://github.com/tingofurro/summac/raw/master/summac_conv_vitc_sent_perc_e.bin")
                assert bins == "percentile", "bins mode should be set to percentile if using the default 1-d convolution weights."
        if start_file is not None:
            print(self.load_state_dict(torch.load(start_file)))

    def build_image(self, original, generated):
        """Concatenate the per-imager NLI images along the depth (channel) axis."""
        images = [imager.build_image(original, generated) for imager in self.imagers]
        image = np.concatenate(images, axis=0)
        return image

    def compute_histogram(self, original=None, generated=None, image=None):
        # Takes the two texts, and generates a (n_rows, 2*n_bins)
        """Histogram each summary-chunk column of the image.

        Returns (image, full_histogram) where full_histogram is
        (n_rows, full_size), zero-padded / truncated to exactly n_rows rows.
        """

        if image is None:
            image = self.build_image(original, generated)

        N_depth, N_ori, N_gen = image.shape

        full_histogram = []
        for i_gen in range(N_gen):
            histos = []

            for i_depth in range(N_depth):
                # Depth channels repeat (entailment, contradiction, neutral)
                # per imager; keep only the channels selected via nli_labels.
                if (i_depth % 3 == 0 and "e" in self.nli_labels) or (i_depth % 3 == 1 and "c" in self.nli_labels) or (
                        i_depth % 3 == 2 and "n" in self.nli_labels):
                    histo, X = np.histogram(image[i_depth, :, i_gen], range=(0, 1), bins=self.bins, density=False)
                    histos.append(histo)

            histogram_row = np.concatenate(histos)
            full_histogram.append(histogram_row)

        # Pad with all-zero rows up to n_rows, then truncate; zero rows mark
        # padding and are detected downstream in forward().
        n_rows_missing = self.n_rows - len(full_histogram)
        full_histogram += [[0.0] * self.full_size] * n_rows_missing
        full_histogram = full_histogram[:self.n_rows]
        full_histogram = np.array(full_histogram)
        return image, full_histogram

    def forward(self, originals, generateds, images=None, silent=True):
        """Score a batch of (document, summary) pairs.

        Returns (logits, histograms_out, images); pass precomputed `images`
        to skip the NLI step.
        """
        if images is not None:
            # In case they've been pre-computed.
            histograms = []
            for image in images:
                _, histogram = self.compute_histogram(image=image)
                histograms.append(histogram)
        else:
            images, histograms = [], []
            if silent:
                for original, generated in zip(originals, generateds):
                    image, histogram = self.compute_histogram(original=original, generated=generated)
                    images.append(image)
                    histograms.append(histogram)
            else:
                # Same loop, with a progress bar.
                for original, generated in tqdm(zip(originals, generateds), total=len(originals)):
                    image, histogram = self.compute_histogram(original=original, generated=generated)
                    images.append(image)
                    histograms.append(histogram)

        N = len(histograms)
        histograms = torch.FloatTensor(histograms).to(self.device)

        # All-zero rows are padding; count the real rows of each sample.
        non_zeros = (torch.sum(histograms, dim=-1) != 0.0).long()
        seq_lengths = non_zeros.sum(dim=-1).tolist()

        mlp_outs = self.mlp(histograms).reshape(N, self.n_rows)
        features = []

        for mlp_out, seq_length in zip(mlp_outs, seq_lengths):
            if seq_length > 0:
                Rs = mlp_out[:seq_length]  # scores of the real (non-padding) rows
                if self.agg == "mean":
                    features.append(torch.cat([torch.mean(Rs).unsqueeze(0), torch.mean(Rs).unsqueeze(0),
                                               torch.mean(Rs).unsqueeze(0)]).unsqueeze(0))
                elif self.agg == "min":
                    features.append(torch.cat(
                        [torch.min(Rs).unsqueeze(0), torch.min(Rs).unsqueeze(0), torch.min(Rs).unsqueeze(0)]).unsqueeze(
                        0))
                elif self.agg == "max":
                    features.append(torch.cat(
                        [torch.max(Rs).unsqueeze(0), torch.max(Rs).unsqueeze(0), torch.max(Rs).unsqueeze(0)]).unsqueeze(
                        0))
                elif self.agg == "all":
                    # "all" uses min / mean / max as the three features.
                    features.append(torch.cat([torch.min(Rs).unsqueeze(0), torch.mean(Rs).unsqueeze(0),
                                               torch.max(Rs).unsqueeze(0)]).unsqueeze(0))
            else:
                features.append(torch.FloatTensor([0.0, 0.0, 0.0]).unsqueeze(0))  # .cuda()
        features = torch.cat(features)
        logits = self.layer_final(features)
        histograms_out = [histogram.cpu().numpy() for histogram in histograms]
        return logits, histograms_out, images

    def save_imager_cache(self):
        """Persist every imager's NLI cache to disk."""
        for imager in self.imagers:
            imager.save_cache()

    def score(self, originals, generateds, **kwargs):
        """Return the softmax probability of label index 1 for each pair.

        NOTE(review): index 1 is presumably the "consistent" class per the
        trained checkpoint's convention — confirm against the training code.
        """
        with torch.no_grad():
            logits, histograms, images = self.forward(originals, generateds, silent=False)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            batch_scores = probs[:, 1].tolist()
        return {"scores": batch_scores}  # , "histograms": histograms, "images": images


class SummaCZS:
    """Zero-shot SummaC.

    Scores a summary against a document by reducing the NLI image with two
    operators: `op1` collapses the document-chunk axis (per summary chunk),
    `op2` collapses the resulting per-chunk scores into one scalar.
    """

    def __init__(self, model_name="mnli", granularity="paragraph", op1="max", op2="mean", use_ent=True, use_con=True,
                 imager_load_cache=True, device="cuda", **kwargs):
        assert op2 in ["min", "mean", "max"], "Unrecognized `op2`"
        assert op1 in ["max", "mean", "min"], "Unrecognized `op1`"
        self.device = device
        self.imager = SummaCImager(model_name=model_name, granularity=granularity, device=self.device, **kwargs)
        if imager_load_cache:
            self.imager.load_cache()
        self.op2 = op2  # aggregation over summary chunks (default: mean)
        self.op1 = op1  # aggregation over document chunks (default: max)
        self.use_ent = use_ent
        self.use_con = use_con

    def save_imager_cache(self):
        """Persist the imager's NLI cache to disk."""
        self.imager.save_cache()

    def score_one(self, original, generated):
        """Score a single (document, summary) pair; returns the image too."""
        nli_image = self.imager.build_image(original, generated)
        return {"image": nli_image, "score": self.image2score(nli_image)}

    def image2score(self, image):
        """Collapse a (3, N_doc, N_sum) NLI image into one scalar score."""
        reducers = {"max": np.max, "mean": np.mean, "min": np.min}

        # Per summary chunk: reduce entailment / contradiction over doc chunks.
        reduce_docs = reducers[self.op1]
        ent_scores = reduce_docs(image[0], axis=0)
        co_scores = reduce_docs(image[1], axis=0)

        # Combine the two channels according to the configured flags.
        if self.use_ent and self.use_con:
            scores = ent_scores - co_scores
        elif self.use_ent:
            scores = ent_scores
        elif self.use_con:
            scores = 1.0 - co_scores

        # Reduce over summary chunks.
        return reducers[self.op2](scores)

    # sources are the documents, generateds the summaries (parallel lists).
    def score(self, sources, generateds, batch_size=128, **kwargs):
        """Score each (document, summary) pair; returns scores and NLI images."""
        nli_images = self.imager.build_images(sources, generateds, batch_size=batch_size)
        return {"scores": [self.image2score(im) for im in nli_images], "images": nli_images}


if __name__ == "__main__":
    # Device can be `cpu` or `cuda` when GPU is available
    model = SummaCZS(granularity="document", model_name="vitc", imager_load_cache=True, device="cpu")

    document = "Jeff joined Microsoft in 1992 to lead corporate developer evangelism for Windows NT."
    summary1 = "Jeff joined Microsoft in 1992."
    summary2 = "Jeff joined Microsoft."

    # Scores both summaries against the same document in one batch.
    print(model.score([document, document], [summary1, summary2])["scores"])

    # document = """Jeff joined Microsoft in 1992 to lead corporate developer evangelism for Windows NT. He then served as a Group Program manager in Microsoft's Internet Business Unit. In 1998, he led the creation of SharePoint Portal Server, which became one of Microsoft’s fastest-growing businesses, exceeding $2 billion in revenues. Jeff next served as Corporate Vice President for Program Management across Office 365 Services and Servers, which is the foundation of Microsoft's enterprise cloud leadership. He then led Corporate Strategy supporting Satya Nadella and Amy Hood on Microsoft's mobile-first/cloud-first transformation and acquisitions. Prior to joining Microsoft, Jeff was vice president for software development for an investment firm in New York. He leads Office shared experiences and core applications, as well as OneDrive and SharePoint consumer and business services in Office 365. Jeff holds a Master of Business Administration degree from Harvard Business School and a Bachelor of Science degree in information systems and finance from New York University."""
    # summary = "Jeff joined Microsoft in 1992 to lead the company's corporate evangelism. He then served as a Group Manager in Microsoft's Internet Business Unit. In 1998, Jeff led Sharepoint Portal Server, which became the company's fastest-growing business, surpassing $3 million in revenue. Jeff next leads corporate strategy for SharePoint and Servers which is the basis of Microsoft's cloud-first strategy. He leads corporate strategy for Satya Nadella and Amy Hood on Microsoft's mobile-first."

    # scores = model.score([document], [summary])["images"][0][0].T
    # summary_sentences = model.imager.split_text(summary)

    # print(np.array2string(scores, precision=2))
    # for score_row, sentence in zip(scores, summary_sentences):
    #     print("-----------")
    #     print("[SummaC score: %.3f; supporting sentence: %d] %s " % (np.max(score_row), np.argmax(score_row)+1, sentence))
